{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Matching Networks Using Tensorflow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "First, we import the libraries, "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "slim = tf.contrib.slim\n",
    "rnn = tf.contrib.rnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we define a class called Matching_network where we define our network. Check the comment on each line of code for understanding. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Matching_network():\n",
    "    \n",
    "    \n",
    "    #initialize all the variables\n",
    "    def __init__(self, lr, n_way, k_shot, batch_size=32):\n",
    "        \n",
    "        #placeholder for support set\n",
    "        self.support_set_image = tf.placeholder(tf.float32, [None, n_way * k_shot, 28, 28, 1])\n",
    "        self.support_set_label = tf.placeholder(tf.int32, [None, n_way * k_shot, ])\n",
    "        \n",
    "        #placeholder for query set\n",
    "        self.query_image = tf.placeholder(tf.float32, [None, 28, 28, 1])\n",
    "        self.query_label = tf.placeholder(tf.int32, [None, ])\n",
    "\n",
    "\n",
    "    #encoder function for extracting features from the image\n",
    "    def image_encoder(self, image):\n",
    "     \n",
    "    \n",
    "        with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):\n",
    "            #conv1\n",
    "            net = slim.conv2d(image)\n",
    "            net = slim.max_pool2d(net, [2, 2])\n",
    "            \n",
    "            #conv2\n",
    "            net = slim.conv2d(net)\n",
    "            net = slim.max_pool2d(net, [2, 2])\n",
    "            \n",
    "            #conv3\n",
    "            net = slim.conv2d(net)\n",
    "            net = slim.max_pool2d(net, [2, 2])\n",
    "            \n",
    "            #conv4\n",
    "            net = slim.conv2d(net)\n",
    "            net = slim.max_pool2d(net, [2, 2])\n",
    "            \n",
    "        return tf.reshape(net, [-1, 1 * 1 * 64])\n",
    "    \n",
    "       \n",
    "    #embedding function for extracting support set embeddings\n",
    "    def g(self, x_i):\n",
    "\n",
    "        forward_cell = rnn.BasicLSTMCell(32)\n",
    "        backward_cell  = rnn.BasicLSTMCell(32)\n",
    "        outputs, state_forward, state_backward = rnn.static_bidirectional_rnn(forward_cell, backward_cell, x_i, dtype=tf.float32)\n",
    "\n",
    "        return tf.add(tf.stack(x_i), tf.stack(outputs))\n",
    "\n",
    "    \n",
    "    #embedding function for extracting query set embeddings\n",
    "    def f(self, XHat, g_embedding):\n",
    "        cell = rnn.BasicLSTMCell(64)\n",
    "        prev_state = cell.zero_state(self.batch_size, tf.float32) \n",
    "\n",
    "        for step in xrange(self.processing_steps):\n",
    "            output, state = cell(XHat, prev_state)\n",
    "            \n",
    "            h_k = tf.add(output, XHat) \n",
    "\n",
    "            content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding))  \n",
    "            \n",
    "            r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0)      \n",
    "\n",
    "            prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))\n",
    "\n",
    "        return output\n",
    "\n",
    "    #cosine similarity function for calculating cosine similarity between support set and query set embeddings\n",
    "    def cosine_similarity(self, target, support_set):\n",
    "        target_normed = target\n",
    "        sup_similarity = []\n",
    "        for i in tf.unstack(support_set):\n",
    "            i_normed = tf.nn.l2_normalize(i, 1) \n",
    "            similarity = tf.matmul(tf.expand_dims(target_normed, 1), tf.expand_dims(i_normed, 2)) \n",
    "            sup_similarity.append(similarity)\n",
    "\n",
    "        return tf.squeeze(tf.stack(sup_similarity, axis=1)) \n",
    "    \n",
    "\n",
    "    def train(self, support_set_image, support_set_label, query_image):    \n",
    "        \n",
    "        #encode the features of query set images using our image encoder\n",
    "        query_image_encoded = self.image_encoder(query_image)  \n",
    "        \n",
    "        #encode the features of support set images using our image encoder\n",
    "        support_set_image_encoded = [self.image_encoder(i) for i in tf.unstack(support_set_image, axis=1)]\n",
    "        \n",
    "        #generate support set embeddings using our embedding function g\n",
    "        g_embedding = self.g(support_set_image_encoded)    \n",
    "        \n",
    "        #generate query set embeddings using our embedding function f\n",
    "        f_embedding = self.f(query_image_encoded, g_embedding)    \n",
    "\n",
    "        #calculate the cosine similarity between both of these embeddings\n",
    "        embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding) \n",
    "        \n",
    "        #perform attention over the embedding similarity\n",
    "        attention = tf.nn.softmax(embeddings_similarity)\n",
    "        \n",
    "        #now predict query set label by multiplying attention matrix with one hot encoded support set labels\n",
    "        y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))\n",
    "        \n",
    "        #get the probabilities \n",
    "        probabilities = tf.squeeze(y_hat)   \n",
    "        \n",
    "        #select the index which has the highest probability as a class of query image\n",
    "        predictions = tf.argmax(self.probabilities, 1)\n",
    "        \n",
    "        #we use softmax cross entropy loss as our loss function\n",
    "        loss_function = tf.losses.sparse_softmax_cross_entropy(label, self.probabilities)\n",
    "        \n",
    "        #we minimize the loss using adam optimizer\n",
    "        tf.train.AdamOptimizer(self.lr).minimize(self.loss_op)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
